from transforms import get_transforms
from torch.utils.data import Dataset, DataLoader
import wandb
import torch
from utils import upscale_tensor, find_first_checkpoint
from collections import OrderedDict


class Support_Vectors(Dataset):
    """Dataset of learned support tensors paired with reference-model outputs.

    The support tensors and their lambdas are read from a checkpoint's
    ``state_dict`` (keys ``trainable_params_1`` and ``trainable_lambda``).
    A ``LitResnet`` built from ``config`` is loaded with the reference
    weights and run once over all support tensors; the softmax of its output
    is cached and served alongside each tensor.
    """

    def __init__(self, REF_MODEL_PATH, SUPPORT_PATH, config, is_intersec = False):
        """Load the support tensors and precompute the reference outputs.

        Args:
            REF_MODEL_PATH: Path to the reference model checkpoint; it must
                contain a ``'state_dict'`` entry.
            SUPPORT_PATH: Location handed to ``find_first_checkpoint`` to
                resolve the support-vector checkpoint.
            config: Mapping that provides ``'NUM_CLASSES'`` and is passed to
                ``LitResnet``.
            is_intersec: Accepted for interface compatibility; not used by
                this class.
        """
        # Imported here rather than at module level, as in the original.
        # NOTE(review): presumably avoids a circular import -- confirm.
        from models import LitResnet

        super().__init__()

        SUPPORT_PATH = find_first_checkpoint(SUPPORT_PATH)
        support_state = torch.load(SUPPORT_PATH)['state_dict']
        self.tensors = support_state['trainable_params_1']
        self.tensors_lambda = support_state['trainable_lambda']

        num_classes = config['NUM_CLASSES']
        num_samples = len(self.tensors_lambda) // num_classes
        # Labels are laid out class by class: num_samples zeros, then
        # num_samples ones, and so on.
        self.targets = (
            torch.arange(0, num_classes)
            .unsqueeze(1)
            .repeat(1, num_samples)
            .flatten()
            .to("cuda:0")
        )

        frozen_net = LitResnet(config)
        # Drop the first six characters of every checkpoint key before
        # loading into ``frozen_net.model``.
        # NOTE(review): presumably strips a "model." prefix -- confirm.
        new_state_dict = OrderedDict(
            (key[6:], value)
            for key, value in torch.load(REF_MODEL_PATH)['state_dict'].items()
        )
        frozen_net.model.load_state_dict(new_state_dict)
        model = frozen_net.cuda()
        # NOTE(review): the model is left in whatever train/eval mode
        # LitResnet is constructed in; confirm whether eval() is intended.

        # Inputs whose last dimension is below 32 are upscaled by the integer
        # factor 32 // size before being fed to the model.
        if self.tensors.shape[-1] < 32:
            self.tensors = upscale_tensor(self.tensors, 32 // self.tensors.shape[-1])

        # The model is only used to produce fixed outputs, so skip building
        # an autograd graph that would otherwise be kept alive by the cache.
        with torch.no_grad():
            self.outputs = model(self.tensors).softmax(dim=-1)

    def __len__(self):
        """Return the number of support tensors."""
        # The original returned ``self.len``, which is never assigned.
        return len(self.tensors)

    def __getitem__(self, idx):
        """Return ``(tensor, target, output, output)`` for index ``idx``.

        The cached reference output appears twice, as in the original.
        NOTE(review): presumably callers unpack four values -- confirm.
        """
        reference_output = self.outputs[idx]
        return self.tensors[idx], self.targets[idx], reference_output, reference_output
